import os
import json
import random

import cv2
import numpy as np


def load_json(path):
    """Load and return the JSON document stored at ``path``.

    The file is decoded as UTF-8 (the encoding JSON mandates) rather than the
    platform's locale encoding, so non-ASCII keys such as image filenames load
    the same way on every machine.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)


class OccludedHumanPartAugmentation:
    """Paste a randomly sampled, randomly transformed RGBA occluder onto an image.

    Occluder crops are grouped into the levels listed in ``LEVELS``. The
    on-disk layout read by ``_load_dataset`` is::

        <dataset_dir>/labels_<level>.json      # {filename: {"scale": ...}, ...}
        <dataset_dir>/images/<level>/<filename>

    Args:
        dataset_dir: Root directory of the occluder dataset.
        flip_rate: Probability of horizontally flipping the sampled occluder.
        probs: Sampling probability for each entry of ``LEVELS``, in order.
            Levels with probability 0 are not loaded at all. Passed as ``p``
            to ``np.random.choice``, so it must have one entry per level and
            sum to 1.
        img_shape: Output size as ``(width, height)``. It is passed to
            ``cv2.warpAffine`` as ``dsize``, so images given to ``__call__``
            are expected to be ``height x width``.
    """

    LEVELS = ["l1", "l2", "l3", "full_human"]

    def __init__(self, dataset_dir, flip_rate=0.5, probs=(0, 1, 0, 0), img_shape=(256, 256)):
        self.dataset_dir = dataset_dir
        self.db = self._load_dataset(dataset_dir, probs)
        self.probs = probs

        self.img_shape = img_shape
        self.flip_rate = flip_rate

    def _load_dataset(self, dataset_dir, probs):
        """Return, per level, a list of ``(filename, scale)`` pairs.

        A level whose probability is 0 gets an empty list and none of its
        files are read. Every file in a loaded level's image directory must
        have an entry in that level's label file, otherwise ``KeyError`` is
        raised.
        """
        filenames_with_scales = []
        for level, prob in zip(self.LEVELS, probs):
            if prob == 0:
                filenames_with_scales.append([])
                continue

            annts = load_json(os.path.join(dataset_dir, f"labels_{level}.json"))
            image_dir = os.path.join(dataset_dir, "images", level)
            # os.listdir order is arbitrary; sorting keeps the index -> file
            # mapping, and therefore seeded sampling, reproducible.
            filenames_with_scales.append(
                [(fname, annts[fname]["scale"]) for fname in sorted(os.listdir(image_dir))]
            )
        return filenames_with_scales

    def _sample_img_path(self):
        """Pick a level according to ``self.probs``, then one of its files uniformly.

        Returns:
            ``(path, scale)``: the occluder image path and its annotated
            ``"scale"`` value.
        """
        level_idx = np.random.choice(len(self.LEVELS), p=self.probs)
        filename, scale = random.choice(self.db[level_idx])
        return os.path.join(self.dataset_dir, "images", self.LEVELS[level_idx], filename), scale

    def _rescale(self, img, scale):
        """Uniformly resize ``img`` relative to the output size ``self.img_shape``.

        ``scale[0]`` and ``scale[1]`` multiply the output width and height
        respectively. The smaller of the two resulting factors is applied to
        both axes, so the aspect ratio of ``img`` is preserved.
        """
        scale_x = scale[0] * self.img_shape[0] / img.shape[1]
        scale_y = scale[1] * self.img_shape[1] / img.shape[0]
        factor = min(scale_x, scale_y)
        return cv2.resize(img, (0, 0), fx=factor, fy=factor)

    def _sample_img(self):
        """Load a random occluder as a 4-channel image and rescale it.

        Raises:
            FileNotFoundError: if OpenCV cannot read the sampled file.
            ValueError: if the file does not decode to a 4-channel image.
        """
        png_path, scale = self._sample_img_path()
        png = cv2.imread(png_path, cv2.IMREAD_UNCHANGED)
        # cv2.imread signals failure by returning None instead of raising.
        if png is None:
            raise FileNotFoundError(f"could not read occluder image: {png_path}")
        if png.ndim != 3 or png.shape[2] != 4:
            raise ValueError(f"occluder image must have 4 channels (with alpha): {png_path}")
        # Zero the colour of fully transparent pixels; presumably so that
        # interpolation during resize/warp does not bleed hidden colour into edges.
        png[png[..., 3] == 0, :3] = 0
        return self._rescale(png, scale)

    def _sample_affine_transform(self):
        """Draw random transform parameters.

        Returns:
            ``(center, scale, rotation_deg, flip)``. ``center`` is an integer
            ``[x, y]`` inside the ``(width, height)`` output frame. ``scale``
            is in [0.75, 1.25], ``rotation_deg`` is in [-30, 30], and ``flip``
            is True with probability ``self.flip_rate``.
        """
        s = random.uniform(0.75, 1.25)
        r = random.uniform(-30, 30)
        tx = int(random.uniform(0, self.img_shape[0]))
        ty = int(random.uniform(0, self.img_shape[1]))
        flip = random.random() < self.flip_rate
        return np.array([tx, ty]), s, r, flip

    def _transform(self, img, c, s, r, flip):
        """Optionally flip ``img`` horizontally, then rotate/scale it about its centre.

        The affine transform rotates by ``r`` degrees and scales by ``s``
        around the image centre. It then translates so that the centre lands
        at ``c`` in an output canvas of size ``self.img_shape``
        (``(width, height)``). Pixels outside the warped image are 0, so their
        alpha is 0 as well.
        """
        if flip:
            img = cv2.flip(img, 1)
        width, height = img.shape[1], img.shape[0]
        rotate_center = (width / 2, height / 2)
        affine_matrix = cv2.getRotationMatrix2D(rotate_center, r, s)
        # getRotationMatrix2D keeps the centre fixed; shift it onto c.
        affine_matrix[:, 2] += c - np.array(rotate_center)
        return cv2.warpAffine(img, affine_matrix, tuple(self.img_shape))

    @staticmethod
    def _composite(foreground, background):
        """Overlay the first three channels of ``foreground`` onto ``background``.

        Pixels whose 4th (alpha) channel is non-zero take ``foreground``'s
        colour; all other pixels keep ``background``. Alpha acts as a binary
        mask, with no partial blending.

        Returns:
            The composited image as ``uint8``.
        """
        mask = foreground[..., 3:4] > 0
        return np.where(mask, foreground[..., :3], background).astype(np.uint8)

    def __call__(self, img):
        """Return a copy of ``img`` with one random occluder pasted on top.

        Args:
            img: Input image, expected in BGR format with shape
                ``(img_shape[1], img_shape[0], 3)`` so that it matches the
                warped occluder.
        """
        # TODO: update joint vis annotation if texture mapping occlude human keypoints
        png = self._sample_img()
        c, s, r, flip = self._sample_affine_transform()
        png_trans = self._transform(png, c, s, r, flip)
        return self._composite(png_trans, img)
